{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Learning Frameworks\n",
    "\n",
    "Deep learning frameworks, like Chainer, PyTorch (whose design drew on Chainer's) and TensorFlow,\n",
    "are designed to make writing neural networks easy.\n",
    "Think back to the tutorial notebook:\n",
    "There were lots of places where we wrote repetitive code.\n",
    "We definitely could have benefited from a framework to help us.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What does a deep learning framework give you?\n",
    "\n",
    "At its core, a deep learning framework gives you the tools needed to handle models, losses and optimizers. It basically gives you a library of:\n",
    "\n",
    "1. functions that you can chain together to express your model\n",
    "1. loss functions that you can compose together to express your final loss function(s)\n",
    "1. gradient-based optimizers that you can use to optimize your parameters\n",
    "\n",
    "Because gradients are so important, the library should also provide\n",
    "an automatic differentiation system that produces a gradient function for you,\n",
    "or else it should implement the gradient of each model function by hand.\n",
    "\n",
    "Beyond that, it can (and probably should) also provide:\n",
    "\n",
    "1. convenience functions for instantiating parameters without needing you to manually specify their shapes\n",
    "1. a syntax for specifying neural network models conveniently\n",
    "1. tooling to help you speed up the training of neural networks, whether it be by distributed training or some form of compilation"
   ]
  },
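  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the three core ingredients concrete, here is a minimal, framework-free sketch in plain Python.\n",
    "The names `model`, `mse_loss`, `mse_grad` and `gd_step`, as well as the toy data, are made up for illustration.\n",
    "The gradient is derived by hand, which is exactly the chore that an automatic differentiation system would take over.\n",
    "\n",
    "```python\n",
    "# Toy data generated from y = 2x + 1 (hypothetical, for illustration only).\n",
    "xs = [0.0, 1.0, 2.0, 3.0]\n",
    "ys = [1.0, 3.0, 5.0, 7.0]\n",
    "\n",
    "\n",
    "def model(params, x):\n",
    "    # 1. The model: a single linear function of x.\n",
    "    return params['w'] * x + params['b']\n",
    "\n",
    "\n",
    "def mse_loss(params, xs, ys):\n",
    "    # 2. The loss: mean squared error between predictions and targets.\n",
    "    return sum((model(params, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)\n",
    "\n",
    "\n",
    "def mse_grad(params, xs, ys):\n",
    "    # Hand-derived gradient of mse_loss with respect to w and b.\n",
    "    n = len(xs)\n",
    "    errors = [model(params, x) - y for x, y in zip(xs, ys)]\n",
    "    return {\n",
    "        'w': sum(2 * e * x for e, x in zip(errors, xs)) / n,\n",
    "        'b': sum(2 * e for e in errors) / n,\n",
    "    }\n",
    "\n",
    "\n",
    "def gd_step(params, grads, lr=0.1):\n",
    "    # 3. The optimizer: one plain gradient descent update.\n",
    "    return {k: params[k] - lr * grads[k] for k in params}\n",
    "\n",
    "\n",
    "params = {'w': 0.0, 'b': 0.0}\n",
    "initial_loss = mse_loss(params, xs, ys)\n",
    "for _ in range(500):\n",
    "    params = gd_step(params, mse_grad(params, xs, ys))\n",
    "final_loss = mse_loss(params, xs, ys)\n",
    "```\n",
    "\n",
    "After the loop, `params` should sit close to the `w = 2`, `b = 1` used to generate the toy data."
   ]
  },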
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reinventing the wheel to learn the wheel\n",
    "\n",
    "To learn how a deep learning framework is structured, we are going to attempt to make one. \n",
    "After all, there is no better way to learn about a wheel than reinventing it.\n",
    "\n",
    "### Course design choices\n",
    "\n",
    "In designing this section of the course, I have made a few choices about how to structure it.\n",
    "\n",
    "Firstly, we are adopting a \"functional\"-first approach to building a deep learning framework.\n",
    "The main reason is that I would like you to see how the math functions that we write in NumPy\n",
    "get translated into neural network layers.\n",
    "They operate on \"data\", which are both the parameters of the model and the data we train the model on.\n",
    "\n",
    "Secondly, and for the same reason, we are avoiding the large frameworks, PyTorch and TensorFlow.\n",
    "This is not to say that they are useless; much to the contrary, learning them is _extremely_ useful.\n",
    "But if the goal is to uncover what they provide, we should stay away from them until we learn the fundamentals.\n",
    "\n",
    "Thirdly, we are going to stick with `jax`, as it is one of the few systems available at the moment\n",
    "that provide automatic differentiation on the NumPy API.\n",
    "Being able to express our layers in the NumPy API has the advantage\n",
    "that we can stick with PyData community idioms.\n",
    "Coincidentally, in `jax`, there are experimental libraries that provide a pedagogical view\n",
    "on what goes into a deep learning library, \n",
    "and in this section, we will use the interplay of mixing and matching them\n",
    "to give you a broad starting view of DL libraries, \n",
    "thus helping you avoid mental lock-in as you go onto the big libraries.\n",
    "\n",
    "To get us started, we are going to use the neural network example from before to anchor us."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "X = pd.read_csv('../data/biodeg_X.csv', index_col=0)\n",
    "y = pd.read_csv('../data/biodeg_y.csv', index_col=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layers of Functions: The Dense Transformation\n",
    "\n",
    "Our model structure was expressed as the following:\n",
    "\n",
    "```python\n",
    "def nn_model(p, x):\n",
    "    # \"a1\" is the activation from layer 1\n",
    "    a1 = np.tanh(np.dot(x, p['w1']) + p['b1'])\n",
    "    # \"a2\" is the activation from layer 2.\n",
    "    # The logistic squashes the output into (0, 1), so we can read it as a probability.\n",
    "    a2 = logistic(np.dot(a1, p['w2']) + p['b2'])\n",
    "    return a2\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Firstly, notice how we keep using the line:\n",
    "\n",
    "```python\n",
    "np.dot(x, p[\"w\"]) + p[\"b\"]\n",
    "```\n",
    "\n",
    "In neural network land, we call that dot product plus bias a \"Dense\" transformation; it is among the most commonly used linear algebra operations in deep learning. \n",
    "This can be expressed in the following form:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as np\n",
    "\n",
    "def dense(params, x):\n",
    "    return np.dot(x, params[\"w\"]) + params[\"b\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, our `nn_model` can be more concisely expressed as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logistic(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "def nn_model(p, x):\n",
    "    a1 = np.tanh(dense(p[\"dense1\"], x))\n",
    "    a2 = logistic(dense(p[\"dense2\"], a1))\n",
    "    return a2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should expect any deep learning library to provide a \"Dense\" layer.\n",
    "\n",
    "Now, if providing a `Dense` layer is all that we expected of deep learning libraries, that would still be a little too easy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters\n",
    "\n",
    "Now that we've figured out how to create a \"dense\" layer, we also need to instantiate the model parameters.\n",
    "\n",
    "Note, though, that we can no longer directly access the parameters `w1`, `b1`, `w2` and `b2`.\n",
    "To keep the `dense` function general, each layer's parameters now live in their own nested dictionary.\n",
    "\n",
    "Since a `dense` layer requires both `w` and `b`, we can write a function that instantiates random numbers for them,\n",
    "with correctly specified shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy.random as npr\n",
    "\n",
    "def dense_params(params, name, input_dim, output_dim):\n",
    "    params[name] = dict()\n",
    "    params[name][\"w\"] = npr.normal(size=(input_dim, output_dim))\n",
    "    params[name][\"b\"] = npr.normal(size=(output_dim,))\n",
    "    return params\n",
    "\n",
    "params = dict()\n",
    "params = dense_params(params, \"dense1\", 41, 20)\n",
    "params = dense_params(params, \"dense2\", 20, 1)\n",
    "params[\"dense2\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To test whether the params work, we can simply pass the data and params through the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn_model(params, X.values).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And with that, we've greatly reduced the amount of boilerplate we need to write when instantiating parameters!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimization Routine and Updating\n",
    "\n",
    "If you remember, our training loop generally looked something like the following:\n",
    "\n",
    "```python\n",
    "losses = []\n",
    "for i in tqdmn(range(1000)):\n",
    "    grad_p = dlogistic_loss(params, X.values, y.values)\n",
    "    for k, v in params.items():\n",
    "        params[k] = params[k] - grad_p[k] * 0.0001\n",
    "    losses.append(logistic_loss(params, X.values, y.values))\n",
    "```\n",
    "\n",
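    "This loop assumes that `params` is a flat dictionary of arrays. Our parameters, however, are now dictionaries nested one level deep, so the update has to be applied to every array inside every layer of a structure like the following:\n",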
    "\n",
    "```python\n",
    "params = {\n",
    "    \"layer1\": {\n",
    "        \"param1\": np.array...,\n",
    "        \"param2\": np.array...,\n",
    "    },\n",
    "    \"layer2\": {\n",
    "        \"param1\": np.array...,\n",
    "        \"param2\": np.array...,\n",
    "        \"param3\": np.array...,\n",
    "    },\n",
    "}\n",
    "```"
   ]
  },
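  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to apply a gradient step to every entry of a nested dictionary like the one above is to loop over layers and then over each layer's parameters.\n",
    "The sketch below assumes that the gradients mirror the layout of the parameters exactly; `update_nested` is a hypothetical helper name,\n",
    "and plain floats stand in for arrays (the same arithmetic applies elementwise to NumPy or `jax` arrays).\n",
    "\n",
    "```python\n",
    "def update_nested(params, grads, lr=0.0001):\n",
    "    # Hypothetical helper: one gradient descent step over a dict of dicts.\n",
    "    # `grads` must have the same layer names and parameter names as `params`.\n",
    "    return {\n",
    "        layer: {\n",
    "            name: value - lr * grads[layer][name]\n",
    "            for name, value in layer_params.items()\n",
    "        }\n",
    "        for layer, layer_params in params.items()\n",
    "    }\n",
    "\n",
    "\n",
    "# Plain floats stand in for arrays to keep the sketch dependency-free.\n",
    "params = {'layer1': {'w': 1.0, 'b': 0.5}, 'layer2': {'w': -2.0, 'b': 0.0}}\n",
    "grads = {'layer1': {'w': 10.0, 'b': -10.0}, 'layer2': {'w': 0.0, 'b': 20.0}}\n",
    "new_params = update_nested(params, grads, lr=0.1)\n",
    "```\n",
    "\n",
    "For arbitrarily deep nesting, `jax.tree_util.tree_map` applies a function to every array in a nested container, which avoids writing one loop per level."
   ]
  },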
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automating Parameter Shape Construction\n",
    "\n",
    "Parameter shapes are also important to keep in mind. The `dense` layer that we defined above _assumes_ the following shapes for our data and params:\n",
    "\n",
    "```\n",
    "# Note: \":\" represents the \"sample\" dimension.\n",
    "(:, 41) @ (41, 20) + (20,)   => (:, 20)\n",
    "x       @ p[\"w\"]   + p[\"b\"]  => out\n",
    "```\n",
    "\n",
    "Recall that when we initialized our parameters, we had to manually specify the parameter names and their shapes.\n",
    "\n",
    "```python\n",
    "params = dict()\n",
    "params['w1'] = noise((41, 20))\n",
    "params['b1'] = noise((20,))\n",
    "params['w2'] = noise((20, 1))\n",
    "params['b2'] = noise((1,))\n",
    "```\n",
    "\n",
    "Doing this manually is extremely tedious! Surely there has to be an easier way to specify them?\n",
    "\n",
    "If we think carefully about what we had to do here, when we specify a neural network layer (such as `dense`), we need to specify the input and output dimensions, ignoring the \"sample\" dimension (by convention, the first dimension in our X and y matrices). So in our example above, the input dimension is `41`, the output dimension is `20`. \n",
    "\n",
    "In turn, the output dimension of the first `dense` layer (`20`) is also the input dimension of the second layer. \n",
    "\n",
    "If we were to design an API (basically a language), we would want to simplify the specification of neural net layers, perhaps using the following API sketch as a basis:\n",
    "\n",
    "```python\n",
    "model = (\n",
    "    Dense(20), Tanh,\n",
    "    Dense(1), Logistic,\n",
    ")\n",
    "```\n",
    "\n",
    "To do this, we will need a way to automatically generate parameters that are of the correct shape.\n",
    "\n",
    "## `stax`\n",
    "\n",
    "What `jax` provides is an experimental module called `stax`, which we can leverage to provide a pedagogical view on how to do this.\n",
    "In `stax`, every neural network layer is represented by a pair of functions, the `init_fun` and the `apply_fun`, which are returned by a layer constructor such as `Dense`. \n",
    "\n",
    "- the `init_fun` takes in a random key and the input shape, and returns the output shape of the layer, as well as parameters that are initialized with the correct shape\n",
    "- the `apply_fun` takes in parameters and data, and returns the output, also correctly shaped.\n",
    "\n",
    "Let's study the `Dense` implementation in `stax`:\n",
    "\n",
    "```python\n",
    "def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):\n",
    "    \"\"\"Layer constructor function for a dense (fully-connected) layer.\"\"\"\n",
    "    def init_fun(rng, input_shape):\n",
    "        output_shape = input_shape[:-1] + (out_dim,)\n",
    "        k1, k2 = random.split(rng)\n",
    "        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))\n",
    "        return output_shape, (W, b)\n",
    "\n",
    "    def apply_fun(params, inputs, **kwargs):\n",
    "        W, b = params\n",
    "        return np.dot(inputs, W) + b\n",
    "\n",
    "    return init_fun, apply_fun\n",
    "```\n",
    "\n",
    "### `apply_fun`\n",
    "\n",
    "The `apply_fun` is easy to understand: it looks identical to the `dense` layer that we defined above, except that the parameters are defined as a tuple instead of a dictionary.\n",
    "\n",
    "### `init_fun`\n",
    "\n",
    "In the `init_fun`, the output shape is defined in a very general fashion. If we structure a dot product using:\n",
    "\n",
    "```python\n",
    "np.dot(inputs, W)\n",
    "```\n",
    "\n",
    "then the number of dimensions in `inputs` doesn't matter, as long as its last dimension matches the first dimension of `W`:\n",
    "\n",
    "- `(n_sample, n_input_dims) @ (n_input_dims, n_output_dims)` will execute correctly\n",
    "- `(n_sample, n_nuisance_dims, n_input_dims) @ (n_input_dims, n_output_dims)` will execute correctly.\n",
    "- `(n_sample, n_nuisance_dims1, ..., n_input_dims) @ (n_input_dims, n_output_dims)` will execute correctly.\n",
    "\n",
    "As such, the output shape is defined as all but the last dimension of the input shape, followed by the output dim.\n",
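    "\n",
    "The shape rule used by `init_fun` can be checked in isolation with plain tuples.\n",
    "In this sketch, `dense_output_shape` is a hypothetical helper name rather than part of `stax`,\n",
    "and `-1` is used as a placeholder for the sample dimension.\n",
    "\n",
    "```python\n",
    "def dense_output_shape(input_shape, out_dim):\n",
    "    # Keep every leading dimension and swap the last one for out_dim.\n",
    "    return tuple(input_shape[:-1]) + (out_dim,)\n",
    "\n",
    "\n",
    "# Two-dimensional input: (samples, features).\n",
    "shape_2d = dense_output_shape((-1, 41), 20)\n",
    "# Three-dimensional input: the extra middle dimension passes through untouched.\n",
    "shape_3d = dense_output_shape((-1, 7, 41), 20)\n",
    "```\n",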
    "\n",
    "`W_init` and `b_init` are nothing more than random number generators. For now, only the shapes matter:\n",
    "\n",
    "- `W` gets shape `(input_shape[-1], out_dim)`\n",
    "- `b` gets shape `(out_dim,)`\n",
    "\n",
    "For now, let's ignore what `glorot_normal()` and `normal()` exactly are; it is sufficient for us to call them random number generators.\n",
    "\n",
    "### Using `stax.Dense`\n",
    "\n",
    "Let's put `stax.Dense` to use so that you can see how it works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: in newer jax releases, stax lives in jax.example_libraries.stax.\n",
    "from jax.experimental.stax import Dense\n",
    "from jax.random import PRNGKey\n",
    "\n",
    "k = PRNGKey(42)  # you can put any arbitrary number there\n",
    "\n",
    "init_fun, apply_fun = Dense(20)\n",
    "output_shape, params = init_fun(k, input_shape=(-1, 41))\n",
    "params[0].shape, params[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "apply_fun(params, X.values).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chaining layers with `stax.serial`\n",
    "\n",
    "And now we can start chaining the Dense layers together by using `stax.serial`. \n",
    "`stax.serial` does a few nice magical things under the hood, and it all stems from the implementation:\n",
    "\n",
    "```python\n",
    "def serial(*layers):\n",
    "    \"\"\"Combinator for composing layers in serial.\n",
    "\n",
    "    Args:\n",
    "    *layers: a sequence of layers, each an (init_fun, apply_fun) pair.\n",
    "\n",
    "    Returns:\n",
    "    A new layer, meaning an (init_fun, apply_fun) pair, representing the serial\n",
    "    composition of the given sequence of layers.\n",
    "    \"\"\"\n",
    "    nlayers = len(layers)\n",
    "    init_funs, apply_funs = zip(*layers)\n",
    "\n",
    "    def init_fun(rng, input_shape):\n",
    "        params = []\n",
    "        for init_fun in init_funs:\n",
    "            rng, layer_rng = random.split(rng)\n",
    "            \n",
    "            # THE MONEY LINE IS HERE! LOOK HERE!\n",
    "            input_shape, param = init_fun(layer_rng, input_shape)\n",
    "            params.append(param)\n",
    "        return input_shape, params\n",
    "\n",
    "    def apply_fun(params, inputs, **kwargs):\n",
    "        rng = kwargs.pop('rng', None)\n",
    "        rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers\n",
    "        for fun, param, rng in zip(apply_funs, params, rngs):\n",
    "            inputs = fun(param, inputs, rng=rng, **kwargs)\n",
    "        return inputs\n",
    "\n",
    "    return init_fun, apply_fun\n",
    "```\n",
    "\n",
    "Here, the key thing to notice is the loop in `init_fun`:\n",
    "\n",
    "```python\n",
    "for init_fun in init_funs:\n",
    "    rng, layer_rng = random.split(rng)\n",
    "    input_shape, param = init_fun(layer_rng, input_shape)\n",
    "```\n",
    "\n",
    "Through this, the `output_shape` of one layer becomes the `input_shape` of the next,\n",
    "thus allowing the parameters to be specified correctly\n",
    "without needing to explicitly and manually specify shapes.\n",
    "(We must admit, after all, that it's simply frustrating\n",
    "to manually keep track of them all!)\n",
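    "\n",
    "To see the shape threading on its own, here is a stripped-down sketch that tracks only shapes, with no random keys and no arrays.\n",
    "`dense_shape` and `serial_shapes` are hypothetical stand-ins written for this illustration, not `stax` functions.\n",
    "\n",
    "```python\n",
    "def dense_shape(out_dim):\n",
    "    # Hypothetical stand-in for a Dense layer's init_fun; shapes only.\n",
    "    def init_shape(input_shape):\n",
    "        param_shapes = ((input_shape[-1], out_dim), (out_dim,))\n",
    "        return input_shape[:-1] + (out_dim,), param_shapes\n",
    "    return init_shape\n",
    "\n",
    "\n",
    "def serial_shapes(*init_shapes):\n",
    "    # Feed each layer's output shape into the next layer in turn.\n",
    "    def init_shape(input_shape):\n",
    "        all_param_shapes = []\n",
    "        for layer_init in init_shapes:\n",
    "            input_shape, param_shapes = layer_init(input_shape)\n",
    "            all_param_shapes.append(param_shapes)\n",
    "        return input_shape, all_param_shapes\n",
    "    return init_shape\n",
    "\n",
    "\n",
    "init_shape = serial_shapes(dense_shape(20), dense_shape(1))\n",
    "output_shape, param_shapes = init_shape((-1, 41))\n",
    "```\n",
    "\n",
    "Only the input shape `(-1, 41)` and the two output dimensions are given; the `(20, 1)` weight shape of the second layer is inferred.\n",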
    "\n",
    "## Using `stax.serial`\n",
    "\n",
    "Let's put it to use to make sure we can see how it works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.stax import Tanh, Elu, serial\n",
    "init_fun, apply_fun = serial(\n",
    "    Dense(20),\n",
    "    Elu,\n",
    "    Dense(1),\n",
    "    Tanh\n",
    ")\n",
    "\n",
    "output_shape, params = init_fun(k, input_shape=(-1, 41))\n",
    "output_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params[0][0].shape, params[0][1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params[2][0].shape, params[2][1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "apply_fun(params, X.values).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, by leveraging **compositionality**, we can use simple syntax to chain layers together without needing to manually specify input/output shapes. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl-workshop",
   "language": "python",
   "name": "dl-workshop"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
